"""Instantiate pin cell Cells and Universes for core model."""

from math import sqrt

import numpy as np
import openmc
from openmc.model import subdivide

from nuscale.materials import mats
from nuscale.surfaces import surfs


def make_pin(name, surfaces, materials, grid=None):
    """Construct a pin cell Universe with radially layered Cells.

    Parameters
    ----------
    name : str
        The string name to assign to the Universe and each of its Cells
    surfaces : Sequence of openmc.ZCylinder
        The surfaces, ordered from innermost to outermost, between which
        Cells are constructed to comprise the radially layered pin cell
        Universe. At least one surface is required.
    materials : Sequence of openmc.Material
        The Materials used within each radial layer, ordered from the
        innermost region outward. This collection must be exactly one unit
        longer than the collection of surfaces.
    grid : {'htp', 'htm'}, optional
        The type of spacer grid to wrap around the pin cell universe.
        'htp' grids are filled with mats['zircaloy'] and 'htm' grids with
        mats['inconel']. If not given (or falsy), no grid cell is added.

    Returns
    -------
    universe : openmc.Universe
        The pin cell Universe

    Raises
    ------
    ValueError
        If no surfaces are given, if the number of materials is not one more
        than the number of surfaces, or if `grid` is not a recognized type.
    """
    surfaces = list(surfaces)
    materials = list(materials)

    # Validate everything before building any Cells so a bad call never
    # leaves a half-constructed universe behind (and never kills the process).
    if not surfaces:
        raise ValueError('make_pin "{}": at least one surface is required'
                         .format(name))
    if len(materials) != len(surfaces) + 1:
        raise ValueError(
            'make_pin "{}": expected {} materials for {} surfaces, got {}'
            .format(name, len(surfaces) + 1, len(surfaces), len(materials)))

    # Spacer grid type -> key of the material that fills the grid strap
    grid_materials = {'htp': 'zircaloy', 'htm': 'inconel'}
    if grid and grid not in grid_materials:
        raise ValueError('"{}" is not a valid grid type; expected one of {}'
                         .format(grid, sorted(grid_materials)))

    universe = openmc.Universe(name=name)

    # Cell for the interior of the innermost ZCylinder
    universe.add_cell(openmc.Cell(name=name + ' (0)', fill=materials[0],
                                  region=-surfaces[0]))

    # Annular cells between each pair of consecutive ZCylinders; layer i is
    # bounded by surfaces[i-1] (inside) and surfaces[i] (outside).
    for i, (inner, outer) in enumerate(zip(surfaces[:-1], surfaces[1:]),
                                       start=1):
        cell = openmc.Cell(name='{} ({})'.format(name, i), fill=materials[i],
                           region=+inner & -outer)
        universe.add_cell(cell)

    # Cell for the exterior of the outermost ZCylinder
    outer_cell = openmc.Cell(name=name + ' (last)', fill=materials[-1],
                             region=+surfaces[-1])
    universe.add_cell(outer_cell)

    if grid:
        # Clip the outermost layer at the inner grid box surface and fill the
        # band between the inner and outer grid box surfaces with the grid
        # material. NOTE(review): nothing is defined outside
        # 'spacer grid box out'; presumably the enclosing lattice cell ends
        # there -- confirm.
        outer_cell.region &= -surfs['spacer grid box in']
        grid_cell = openmc.Cell(
            name=name + grid + ' (grid)',
            fill=mats[grid_materials[grid]],
            region=+surfs['spacer grid box in'] & -surfs['spacer grid box out'])
        universe.add_cell(grid_cell)

    return universe


def make_stack(name, surfaces, universes):
    """Construct a Universe of axially stacked pin cell Universes.

    Parameters
    ----------
    name : str
        The string name to assign to the Universe and each of its Cells
    surfaces : Sequence of openmc.ZPlane
        A collection of axial surfaces, ordered from bottom to top, between
        which pin cells are filled to comprise an axially stacked pin cell.
        At least one surface is required.
    universes : Sequence of openmc.Universe
        The Universes used within each axial layer, ordered from bottom to
        top. This collection must be exactly one unit longer than the
        collection of surfaces.

    Returns
    -------
    universe : openmc.Universe
        The axially stacked Universe; Cell ``i`` is named ``'<name> (i)'``.

    Raises
    ------
    ValueError
        If `surfaces` is empty or `universes` does not have exactly one more
        element than `surfaces`.
    """
    surfaces = list(surfaces)
    universes = list(universes)

    # Without these checks zip() would silently drop layers and leave
    # undefined regions in the geometry.
    if not surfaces:
        raise ValueError('make_stack "{}": at least one surface is required'
                         .format(name))
    if len(universes) != len(surfaces) + 1:
        raise ValueError(
            'make_stack "{}": expected {} universes for {} surfaces, got {}'
            .format(name, len(surfaces) + 1, len(surfaces), len(universes)))

    universe = openmc.Universe(name=name)

    # subdivide() yields one region per axial segment: below the first
    # surface, between consecutive surfaces, and above the last surface.
    for i, (univ, region) in enumerate(zip(universes, subdivide(surfaces))):
        cell = openmc.Cell(name='{} ({})'.format(name, i), fill=univ,
                           region=region)
        universe.add_cell(cell)

    return universe


def make_pin_stack(name, zsurfaces, universes, boundary, fuel_fill):
    """Construct a Universe of axially stacked universes with a single inner
    fuel pin universe.

    Everything outside `boundary` is divided axially by `zsurfaces` and
    filled layer by layer with `universes`. Everything inside `boundary` is
    a single Cell filled with `fuel_fill`, which has no axial bounds of its
    own.

    Parameters
    ----------
    name : str
        The string name to assign to the Universe and each of its Cells
    zsurfaces : Sequence of openmc.ZPlane
        A collection of axial surfaces, ordered from bottom to top, between
        which pin cells are filled to comprise an axially stacked pin cell.
        At least one surface is required.
    universes : Sequence of openmc.Universe
        The Universes used within each axial layer, ordered from bottom to
        top. This collection must be exactly one unit longer than the
        collection of surfaces.
    boundary : openmc.Surface
        Boundary between the fuel pin itself and everything outside (gap,
        clad, moderator)
    fuel_fill : openmc.Universe or openmc.Material
        Universe or material for (possibly subdivided) fuel

    Returns
    -------
    universe : openmc.Universe
        The pin cell Universe. Outer cells are named ``'<name> (o<i>)'`` and
        the inner fuel cell ``'<name> (i)'``.

    Raises
    ------
    ValueError
        If `zsurfaces` is empty or `universes` does not have exactly one more
        element than `zsurfaces`.
    """
    zsurfaces = list(zsurfaces)
    universes = list(universes)

    # Without these checks zip() would silently drop layers and leave
    # undefined regions in the geometry.
    if not zsurfaces:
        raise ValueError('make_pin_stack "{}": at least one surface is '
                         'required'.format(name))
    if len(universes) != len(zsurfaces) + 1:
        raise ValueError(
            'make_pin_stack "{}": expected {} universes for {} surfaces, got {}'
            .format(name, len(zsurfaces) + 1, len(zsurfaces), len(universes)))

    universe = openmc.Universe(name=name)

    # Outer layers: each axial slab restricted to the outside of the boundary
    for i, (univ, slab) in enumerate(zip(universes, subdivide(zsurfaces))):
        cell = openmc.Cell(name='{} (o{})'.format(name, i), fill=univ,
                           region=slab & +boundary)
        universe.add_cell(cell)

    # Inner fuel region
    cell = openmc.Cell(name='{} (i)'.format(name), fill=fuel_fill,
                       region=-boundary)
    universe.add_cell(cell)

    return universe


def pin_universes(grid=True):
    """Generate universes for NuScale Benchmark fuel, GT, CR pins.

    Parameters
    ----------
    grid : bool, optional
        If True (the default), spacer grids are modelled explicitly:
        grid-wrapped variants of the pins are built and placed between the
        'sgN bottom' and 'sgN top' planes of each axial stack (an 'htm' grid
        at sg1, 'htp' grids at sg2-sg5). If False, the axial stacks contain
        no spacer grid layers.

    Returns
    -------
    dict
        Dictionary mapping descriptive keys (e.g. 'A01', 'A01 stack',
        'GT empty stack', 'GT CR ejected stack', 'GT CR inserted stack') to
        openmc.Universe objects.
    """
    univs = {}

    # Single-material universes for the nozzles and bulk coolant; each holds
    # one Cell created without a region.
    for key, mat_name in (('bottom nozzle', 'Bottom nozzle'),
                          ('top nozzle', 'Top nozzle'),
                          ('coolant', 'coolant')):
        cell = openmc.Cell(name=key + ' cell', fill=mats[mat_name])
        univs[key] = openmc.Universe(name=key + ' universe')
        univs[key].add_cell(cell)

    # Stand-alone spacer grid universes: only the band between the inner and
    # outer grid box surfaces is filled. A fresh region is built per cell so
    # the two cells never share a mutable Region object.
    for label, mat_name in (('HTP', 'zircaloy'), ('HTM', 'inconel')):
        key = label + ' spacer grid'
        cell = openmc.Cell(name=key + ' cell')
        cell.region = (+surfs['spacer grid box in']
                       & -surfs['spacer grid box out'])
        cell.fill = mats[mat_name]
        univs[key] = openmc.Universe(name=key + ' universe')
        univs[key].add_cell(cell)

    # Radial layouts shared by several pins (make_pin only reads these lists)
    gt_surfs = [surfs['GT IR'], surfs['GT OR']]
    gt_mats = [mats['coolant'], mats['zircaloy'], mats['coolant']]
    fuel_surfs = [surfs['Fuel OR'], surfs['Cladding IR'], surfs['Cladding OR']]
    plenum_surfs = [surfs['Inconel OR'], surfs['Cladding IR'],
                    surfs['Cladding OR']]
    plenum_mats = [mats['inconel'], mats['He'], mats['zircaloy'],
                   mats['coolant']]

    # ------------------------------------------------------------------
    # Empty guide tube
    # ------------------------------------------------------------------
    univs['GT empty'] = make_pin('GT empty', gt_surfs, gt_mats)
    if grid:
        for grid_type in ('htp', 'htm'):
            key = 'GT empty {} grid'.format(grid_type)
            univs[key] = make_pin(key, gt_surfs, gt_mats, grid_type)
        gt = univs['GT empty']
        gt_htp = univs['GT empty htp grid']
        gt_htm = univs['GT empty htm grid']
        stack_surfs = [
            surfs['bottom nozzle top'],
            surfs['sg1 bottom'],
            surfs['sg1 top'],
            surfs['sg2 bottom'],
            surfs['sg2 top'],
            surfs['sg3 bottom'],
            surfs['sg3 top'],
            surfs['sg4 bottom'],
            surfs['sg4 top'],
            surfs['sg5 bottom'],
            surfs['sg5 top'],
            surfs['top nozzle bottom'],
        ]
        stack_univs = [univs['bottom nozzle'],
                       gt, gt_htm,
                       gt, gt_htp,
                       gt, gt_htp,
                       gt, gt_htp,
                       gt, gt_htp,
                       gt,
                       univs['top nozzle']]
    else:
        stack_surfs = [
            surfs['bottom nozzle top'],
            surfs['top nozzle bottom'],
        ]
        stack_univs = [univs['bottom nozzle'],
                       univs['GT empty'],
                       univs['top nozzle']]
    univs['GT empty stack'] = make_stack('GT empty stack',
                                         surfaces=stack_surfs,
                                         universes=stack_univs)

    # ------------------------------------------------------------------
    # Fuel rod end caps and plenum spring
    # ------------------------------------------------------------------
    univs['end cap'] = make_pin('end cap',
                                [surfs['Cladding OR']],
                                [mats['zircaloy'], mats['coolant']])
    univs['plenum spring'] = make_pin('plenum spring', plenum_surfs,
                                      plenum_mats)

    # ------------------------------------------------------------------
    # Guide tubes containing control rod segments. Each entry holds the
    # (surfaces, materials) of the rod itself; the guide tube layers are
    # appended when the pin is built.
    # ------------------------------------------------------------------
    cr_layouts = {
        'GT CR AIC': (
            [surfs['AIC OR'], surfs['CR cladding IR'], surfs['CR cladding OR']],
            [mats['AIC'], mats['He'], mats['SS']]),
        'GT CR B4C': (
            [surfs['B4C OR'], surfs['CR cladding IR'], surfs['CR cladding OR']],
            [mats['B4C'], mats['He'], mats['SS']]),
        'GT CR end plug': (
            [surfs['CR cladding OR']],
            [mats['SS']]),
        'GT CR upper plenum': (
            [surfs['CR cladding IR'], surfs['CR cladding OR']],
            [mats['He'], mats['SS']]),
    }
    for key, (rod_surfs, rod_mats) in cr_layouts.items():
        univs[key] = make_pin(key, rod_surfs + gt_surfs, rod_mats + gt_mats)

    if grid:
        for key, grid_type in (('GT CR AIC', 'htp'),
                               ('GT CR end plug', 'htm'),
                               ('GT CR B4C', 'htp'),
                               ('GT CR upper plenum', 'htp')):
            rod_surfs, rod_mats = cr_layouts[key]
            grid_key = '{} {} grid'.format(key, grid_type)
            univs[grid_key] = make_pin(grid_key, rod_surfs + gt_surfs,
                                       rod_mats + gt_mats, grid=grid_type)

        # Ejected control rod: the rod sits above sg4
        ejected_surfs = [
            surfs['bottom nozzle top'],
            surfs['sg1 bottom'],
            surfs['sg1 top'],
            surfs['sg2 bottom'],
            surfs['sg2 top'],
            surfs['sg3 bottom'],
            surfs['sg3 top'],
            surfs['sg4 bottom'],
            surfs['sg4 top'],
            surfs['cr ej end plug bottom'],
            surfs['cr ej aic bottom'],
            surfs['sg5 bottom'],
            surfs['sg5 top'],
            surfs['top nozzle bottom'],
        ]
        ejected_univs = [univs['bottom nozzle'],
                         univs['GT empty'],
                         univs['GT empty htm grid'],
                         univs['GT empty'],
                         univs['GT empty htp grid'],
                         univs['GT empty'],
                         univs['GT empty htp grid'],
                         univs['GT empty'],
                         univs['GT empty htp grid'],
                         univs['GT empty'],
                         univs['GT CR end plug'],
                         univs['GT CR AIC'],
                         univs['GT CR AIC htp grid'],
                         univs['GT CR AIC'],
                         univs['top nozzle']]

        # Inserted control rod: the end plug starts inside sg1
        inserted_surfs = [
            surfs['bottom nozzle top'],
            surfs['sg1 bottom'],
            surfs['cr ins end plug bottom'],
            surfs['sg1 top'],
            surfs['cr ins end plug top'],
            surfs['cr ins aic top'],
            surfs['sg2 bottom'],
            surfs['sg2 top'],
            surfs['sg3 bottom'],
            surfs['sg3 top'],
            surfs['sg4 bottom'],
            surfs['sg4 top'],
            surfs['cr ins b4c top'],
            surfs['sg5 bottom'],
            surfs['sg5 top'],
            surfs['top nozzle bottom'],
        ]
        inserted_univs = [univs['bottom nozzle'],
                          univs['GT empty'],
                          univs['GT empty htm grid'],
                          univs['GT CR end plug htm grid'],
                          univs['GT CR end plug'],
                          univs['GT CR AIC'],
                          univs['GT CR B4C'],
                          univs['GT CR B4C htp grid'],
                          univs['GT CR B4C'],
                          univs['GT CR B4C htp grid'],
                          univs['GT CR B4C'],
                          univs['GT CR B4C htp grid'],
                          univs['GT CR B4C'],
                          univs['GT CR upper plenum'],
                          univs['GT CR upper plenum htp grid'],
                          univs['GT CR upper plenum'],
                          univs['top nozzle']]
    else:
        ejected_surfs = [
            surfs['bottom nozzle top'],
            surfs['cr ej end plug bottom'],
            surfs['cr ej aic bottom'],
            surfs['top nozzle bottom'],
        ]
        ejected_univs = [univs['bottom nozzle'],
                         univs['GT empty'],
                         univs['GT CR end plug'],
                         univs['GT CR AIC'],
                         univs['top nozzle']]

        # Same axial layout as the gridded inserted stack, minus the grids
        inserted_surfs = [
            surfs['bottom nozzle top'],
            surfs['cr ins end plug bottom'],
            surfs['cr ins end plug top'],
            surfs['cr ins aic top'],
            surfs['cr ins b4c top'],
            surfs['top nozzle bottom'],
        ]
        inserted_univs = [univs['bottom nozzle'],
                          univs['GT empty'],
                          univs['GT CR end plug'],
                          univs['GT CR AIC'],
                          univs['GT CR B4C'],
                          univs['GT CR upper plenum'],
                          univs['top nozzle']]

    univs['GT CR ejected stack'] = make_stack('GT CR ejected stack',
                                              surfaces=ejected_surfs,
                                              universes=ejected_univs)
    univs['GT CR inserted stack'] = make_stack('GT CR inserted stack',
                                               surfaces=inserted_surfs,
                                               universes=inserted_univs)

    # ------------------------------------------------------------------
    # Fuel pins and fuel stacks
    # ------------------------------------------------------------------
    if grid:
        # The plenum spring grid pin does not depend on the fuel type, so it
        # is built once and shared by every fuel stack.
        univs['plenum spring grid'] = make_pin('plenum spring grid',
                                               plenum_surfs, plenum_mats,
                                               'htp')
        fuel_stack_surfs = [
            surfs['bottom nozzle top'],
            surfs['end cap lower top'],
            surfs['sg1 bottom'],
            surfs['sg1 top'],
            surfs['sg2 bottom'],
            surfs['sg2 top'],
            surfs['sg3 bottom'],
            surfs['sg3 top'],
            surfs['sg4 bottom'],
            surfs['sg4 top'],
            surfs['fuel top'],
            surfs['sg5 bottom'],
            surfs['sg5 top'],
            surfs['plenum spring top'],
            surfs['end cap upper top'],
            surfs['top nozzle bottom'],
        ]
    else:
        fuel_stack_surfs = [
            surfs['bottom nozzle top'],
            surfs['end cap lower top'],
            surfs['fuel top'],
            surfs['plenum spring top'],
            surfs['end cap upper top'],
            surfs['top nozzle bottom'],
        ]

    for fuel in ('A01', 'A02', 'B01', 'B02', 'C01', 'C02', 'GdC02', 'C03'):
        fuel_mats = [mats[fuel], mats['He'], mats['zircaloy'], mats['coolant']]
        pin = make_pin(fuel, fuel_surfs, fuel_mats)
        univs[fuel] = pin

        if grid:
            for grid_type in ('htp', 'htm'):
                key = '{} {} grid'.format(fuel, grid_type)
                univs[key] = make_pin(key, fuel_surfs, fuel_mats, grid_type)
            pin_htp = univs[fuel + ' htp grid']
            pin_htm = univs[fuel + ' htm grid']
            spring = univs['plenum spring']
            # Fuel crossed by grids sg1-sg4, then the plenum spring crossed
            # by grid sg5.
            middle = [pin, pin_htm,
                      pin, pin_htp,
                      pin, pin_htp,
                      pin, pin_htp,
                      pin,
                      spring, univs['plenum spring grid'], spring]
        else:
            middle = [pin, univs['plenum spring']]

        stack_univs = ([univs['bottom nozzle'], univs['end cap']]
                       + middle
                       + [univs['end cap'], univs['coolant'],
                          univs['top nozzle']])
        univs[fuel + ' stack'] = make_stack(fuel + ' stack',
                                            surfaces=fuel_stack_surfs,
                                            universes=stack_univs)

    return univs
